{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "2d8c6768",
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "from os.path import join, exists\n",
    "import pickle\n",
    "import time\n",
    "from copy import deepcopy\n",
    "import random\n",
    "\n",
    "import torch\n",
    "from src.open_clip.factory import create_model_and_transforms, get_tokenizer\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "from src.training.precision import get_autocast"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "4570586e",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Please specify the root directory of the benchmark datasets. The image paths in the\n",
    "# 'filepath' column of the loaded CSV are joined onto this directory further below.\n",
    "ROOT_DATA_DIR = '/PATH/TO/THE/ROOT/DIRECTORY/OF/BENCHMARK/DATASETS'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "ac519782",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Please specify the directory that holds the benchmark CSV files. Every entry of\n",
    "# data_csv_path_dict is built from it, so pointing the notebook at another machine\n",
    "# only requires editing this single line.\n",
    "BENCHMARK_CSV_DIR = '/home/ubuntu/projects/data/benchmark/open'\n",
    "\n",
    "# Dataset name -> path of the CSV listing its test image/caption pairs.\n",
    "data_csv_path_dict = {\n",
    "    'SkyScript': join(BENCHMARK_CSV_DIR, 'SkyScript_test_30K_filtered_by_CLIP_openai.csv'),\n",
    "    'RSICD': join(BENCHMARK_CSV_DIR, 'RSICD', 'RSICD_img_txt_pairs_test.csv'),\n",
    "    'RSITMD': join(BENCHMARK_CSV_DIR, 'RSITMD', 'RSITMD_img_txt_pairs_test.csv'),\n",
    "    'ucmcaptions': join(BENCHMARK_CSV_DIR, 'ucmcaptions', 'ucmcaptions_img_txt_pairs_test.csv'),\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ad993c64",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples per DataLoader batch in the image and text encoding loops.\n",
    "batch_size = 128\n",
    "# Precision mode string; passed to get_autocast here and to create_model_and_transforms later.\n",
    "precision = 'amp'\n",
    "# Callable used as `with autocast():` around each encoder forward pass.\n",
    "autocast = get_autocast(precision)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c5f9a9fb",
   "metadata": {},
   "outputs": [],
   "source": [
    "class CsvDataset_customized(Dataset):\n",
    "    \"\"\"Dataset of (image, tokenized caption) pairs read from a pandas DataFrame.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with one row per image/caption pair. It is only read and is\n",
    "            never modified by this class.\n",
    "        transforms: Callable applied to each opened PIL image.\n",
    "        img_key: Name of the column holding the image paths.\n",
    "        caption_key: Name of the column holding the captions.\n",
    "        tokenizer: Callable invoked as `tokenizer([caption])`; element 0 of its\n",
    "            result is returned as the tokenized text. Required for indexing.\n",
    "        return_img_path: If True, items additionally carry the image path string.\n",
    "        root_data_dir: Optional directory joined in front of every image path.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, df, transforms, img_key, caption_key, tokenizer=None, return_img_path=False,\n",
    "                 root_data_dir=None):\n",
    "        image_paths = df[img_key].tolist()\n",
    "        if root_data_dir is not None:\n",
    "            # Prefix the paths on a private list rather than writing them back into\n",
    "            # the caller's DataFrame, so the frame can be reused (e.g. to build a\n",
    "            # second dataset) without seeing already-prefixed paths.\n",
    "            image_paths = [join(root_data_dir, path) for path in image_paths]\n",
    "        self.images = image_paths\n",
    "        self.captions = df[caption_key].tolist()\n",
    "        self.transforms = transforms\n",
    "        self.tokenize = tokenizer\n",
    "        self.return_img_path = return_img_path\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of image/caption pairs.\"\"\"\n",
    "        return len(self.captions)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return `(image, tokens)`, followed by the image path if `return_img_path` is set.\"\"\"\n",
    "        img_path = str(self.images[idx])\n",
    "        images = self.transforms(Image.open(img_path))\n",
    "        texts = self.tokenize([str(self.captions[idx])])[0]\n",
    "        if self.return_img_path:\n",
    "            return images, texts, img_path\n",
    "        return images, texts\n",
    "    \n",
    "class CsvDataset_image(Dataset):\n",
    "    \"\"\"Image-only dataset read from a pandas DataFrame.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with a column of image paths. It is only read and is never\n",
    "            modified by this class.\n",
    "        transforms: Callable applied to each opened PIL image.\n",
    "        img_key: Name of the column holding the image paths.\n",
    "        return_img_path: If True, items are `(image, path)` instead of `image`.\n",
    "        root_data_dir: Optional directory joined in front of every image path.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, df, transforms, img_key, return_img_path=False, root_data_dir=None):\n",
    "        image_paths = df[img_key].tolist()\n",
    "        if root_data_dir is not None:\n",
    "            # Prefix the paths on a private list rather than writing them back into\n",
    "            # the caller's DataFrame, so the frame stays reusable afterwards.\n",
    "            image_paths = [join(root_data_dir, path) for path in image_paths]\n",
    "        self.images = image_paths\n",
    "        self.transforms = transforms\n",
    "        self.return_img_path = return_img_path\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of images.\"\"\"\n",
    "        return len(self.images)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the transformed image, followed by its path if `return_img_path` is set.\"\"\"\n",
    "        img_path = str(self.images[idx])\n",
    "        images = self.transforms(Image.open(img_path))\n",
    "        if self.return_img_path:\n",
    "            return images, img_path\n",
    "        return images\n",
    "    \n",
    "class CsvDataset_text(Dataset):\n",
    "    \"\"\"Caption-only dataset read from a pandas DataFrame.\n",
    "\n",
    "    Args:\n",
    "        df: DataFrame with a column of captions. It is only read.\n",
    "        caption_key: Name of the column holding the captions.\n",
    "        tokenizer: Callable invoked as `tokenizer([caption])`; element 0 of its\n",
    "            result is returned as the tokenized text. Required for indexing.\n",
    "        return_original_text: If True, items are `(tokens, caption)` instead of `tokens`.\n",
    "        root_data_dir: Accepted so the signature matches the image datasets, but\n",
    "            ignored: this dataset holds no file paths to prefix.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, df, caption_key, tokenizer=None, return_original_text=False, root_data_dir=None):\n",
    "        # The previous implementation tried to prefix df[img_key] here, but img_key is\n",
    "        # not a parameter of this class, so any non-None root_data_dir raised NameError.\n",
    "        self.captions = df[caption_key].tolist()\n",
    "        self.tokenize = tokenizer\n",
    "        self.return_original_text = return_original_text\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of captions.\"\"\"\n",
    "        return len(self.captions)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"Return the tokenized caption, followed by the raw caption if requested.\"\"\"\n",
    "        original_text = str(self.captions[idx])\n",
    "        texts = self.tokenize([original_text])[0]\n",
    "        if self.return_original_text:\n",
    "            return texts, original_text\n",
    "        return texts\n",
    "\n",
    "def random_seed(seed=42, rank=0):\n",
    "    \"\"\"Seed PyTorch, NumPy and Python's `random` module with `seed + rank`.\"\"\"\n",
    "    effective_seed = seed + rank\n",
    "    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):\n",
    "        seed_fn(effective_seed)\n",
    "\n",
    "def get_sample_identifier(filepath):\n",
    "    \"\"\"Return the last two '/'-separated components of `filepath`, joined by '/'.\"\"\"\n",
    "    tail_parts = filepath.rsplit('/', 2)[-2:]\n",
    "    return '/'.join(tail_parts)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "d79439ab",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "# What to evaluate: model architecture, checkpoint file and benchmark dataset.\n",
    "model_arch_name = 'ViT-L-14' # or 'ViT-B-32'\n",
    "ckpt_name = '/THE/PATH/TO/MODEL/CHECKPOINT' # replace this with the path to the .pt file\n",
    "dataset_name = 'ucmcaptions'\n",
    "# CSV of image/caption pairs for the selected dataset.\n",
    "data_csv_path = data_csv_path_dict[dataset_name]\n",
    "\n",
    "# Run on the GPU when one is available, otherwise on the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = \"cuda\"\n",
    "else:\n",
    "    device = \"cpu\"\n",
    "print(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "d29b14a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# SkyScript keeps its captions in a different column than the other benchmarks.\n",
    "caption_key = 'title_multi_objects' if dataset_name == 'SkyScript' else 'title'\n",
    "\n",
    "force_quick_gelu = True\n",
    "\n",
    "random_seed(42, 0)\n",
    "# Keyword arguments shared by every architecture.\n",
    "model_kwargs = dict(precision=precision, device=device, output_dict=True)\n",
    "if 'ViT-B-32' not in model_arch_name:\n",
    "    # force_quick_gelu is forwarded only when the architecture name does not contain 'ViT-B-32'.\n",
    "    model_kwargs['force_quick_gelu'] = force_quick_gelu\n",
    "model, _, preprocess_val = create_model_and_transforms(model_arch_name, ckpt_name, **model_kwargs)\n",
    "\n",
    "tokenizer = get_tokenizer(model_arch_name)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "9982cbc6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the image/caption pairs and turn the relative image paths into full paths.\n",
    "df = pd.read_csv(data_csv_path)\n",
    "df['filepath'] = [join(ROOT_DATA_DIR, rel_path) for rel_path in df['filepath']]\n",
    "# For these three datasets every caption gets a fixed prefix.\n",
    "if dataset_name in ('RSICD', 'RSITMD', 'ucmcaptions'):\n",
    "    df[caption_key] = ['a satellite image. ' + caption for caption in df[caption_key]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "37a44088",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((210, 3), (377, 3))"
      ]
     },
     "execution_count": 37,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Collapse the pair table to one row per distinct image path and one row per distinct\n",
    "# caption. The remaining columns hold per-group counts.\n",
    "df_image = (\n",
    "    df.groupby('filepath')\n",
    "    .count()\n",
    "    .reset_index()\n",
    ")\n",
    "df_text = (\n",
    "    df.groupby(caption_key)\n",
    "    .count()\n",
    "    .reset_index()\n",
    ")\n",
    "df_image.shape, df_text.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "8343a5bd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 256/256 [00:02<00:00, 105.62it/s]\n"
     ]
    }
   ],
   "source": [
    "# unique images\n",
    "dataset_image = CsvDataset_image(df_image, preprocess_val, 'filepath', return_img_path=True)\n",
    "dataloader = DataLoader(dataset_image, batch_size=batch_size, shuffle=False, num_workers=4)\n",
    "\n",
    "# Encode every distinct image once, keeping the paths in the same order as the features.\n",
    "model.eval()\n",
    "image_feature_chunks = []\n",
    "all_image_paths = []\n",
    "with torch.no_grad():\n",
    "    for images, img_paths in tqdm(dataloader, unit_scale=batch_size):\n",
    "        images = images.to(device=device)\n",
    "        with autocast():\n",
    "            image_features = model.encode_image(images, normalize=True)\n",
    "        # Move each batch to the CPU right away; the batches are concatenated after the loop.\n",
    "        image_feature_chunks.append(image_features.cpu())\n",
    "        all_image_paths.extend(img_paths)\n",
    "all_image_features = torch.cat(image_feature_chunks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "fbf8ed66",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████████| 384/384 [00:00<00:00, 542.15it/s]\n"
     ]
    }
   ],
   "source": [
    "# unique texts\n",
    "dataset_text = CsvDataset_text(df_text, caption_key, tokenizer=tokenizer, return_original_text=True)\n",
    "dataloader = DataLoader(dataset_text, batch_size=batch_size, shuffle=False, num_workers=4)\n",
    "\n",
    "# Encode every distinct caption once, keeping the raw captions in the same order as the features.\n",
    "model.eval()\n",
    "text_feature_chunks = []\n",
    "all_texts = []\n",
    "with torch.no_grad():\n",
    "    for texts, original_texts in tqdm(dataloader, unit_scale=batch_size):\n",
    "        texts = texts.to(device=device)\n",
    "        with autocast():\n",
    "            text_features = model.encode_text(texts, normalize=True)\n",
    "        # Move each batch to the CPU right away; the batches are concatenated after the loop.\n",
    "        text_feature_chunks.append(text_features.cpu())\n",
    "        all_texts.extend(original_texts)\n",
    "all_text_features = torch.cat(text_feature_chunks)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "7814760c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Caption -> row of all_text_features, and image path -> row of all_image_features.\n",
    "text_indices = dict(zip(all_texts, range(len(all_texts))))\n",
    "img_indices = dict(zip(all_image_paths, range(len(all_image_paths))))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "8c03f909",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████████████████████████████████████████| 1050/1050 [00:00<00:00, 44819.60it/s]\n"
     ]
    }
   ],
   "source": [
    "# ground truth\n",
    "# img_path2text: image path -> set of indices of its captions (rows of all_text_features)\n",
    "# text2img_path: caption -> set of indices of its images (rows of all_image_features)\n",
    "img_path2text = {}\n",
    "text2img_path = {}\n",
    "for text, img_path in tqdm(zip(df[caption_key], df['filepath']), total=len(df)):\n",
    "    # Resolve both indices first so a failed lookup leaves the two mappings untouched.\n",
    "    text_id, img_id = text_indices[text], img_indices[img_path]\n",
    "    img_path2text.setdefault(img_path, set()).add(text_id)\n",
    "    text2img_path.setdefault(text, set()).add(img_id)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "898c0c95",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Recall@k accumulators for both retrieval directions, all starting at zero.\n",
    "res = {\n",
    "    direction + '_R@' + str(k): 0\n",
    "    for direction in ('text2img', 'img2text')\n",
    "    for k in [1, 5, 10, 100]\n",
    "}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "554a1f7d",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 377/377 [00:00<00:00, 6980.54it/s]\n"
     ]
    }
   ],
   "source": [
    "# text to image\n",
    "# For every distinct caption, rank all images by similarity and count a hit at k when\n",
    "# at least one of the caption's ground-truth images is among the top k.\n",
    "logit_scale = 100\n",
    "recall_ks = [1, 5, 10, 100]\n",
    "# Hits are tallied locally and then assigned into `res`, so re-running this cell\n",
    "# recomputes the metrics instead of adding to the already-normalised values.\n",
    "text2img_hits = {k: 0 for k in recall_ks}\n",
    "for i in tqdm(range(len(all_texts))):\n",
    "    text_feature = all_text_features[i]\n",
    "    logits = logit_scale * text_feature @ all_image_features.t()\n",
    "    ranking = torch.argsort(logits, descending=True).cpu().numpy()\n",
    "    relevant_images = text2img_path[all_texts[i]]\n",
    "    for k in recall_ks:\n",
    "        if set(ranking[:k]) & relevant_images:\n",
    "            text2img_hits[k] += 1\n",
    "for k in recall_ks:\n",
    "    res['text2img_R@' + str(k)] = text2img_hits[k] / len(all_texts)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "4671d8f7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████| 210/210 [00:00<00:00, 5743.22it/s]\n"
     ]
    }
   ],
   "source": [
    "# image to text\n",
    "# For every distinct image, rank all captions by similarity and count a hit at k when\n",
    "# at least one of the image's ground-truth captions is among the top k.\n",
    "logit_scale = 100\n",
    "recall_ks = [1, 5, 10, 100]\n",
    "# Hits are tallied locally and then assigned into `res`, so re-running this cell\n",
    "# recomputes the metrics instead of adding to the already-normalised values.\n",
    "img2text_hits = {k: 0 for k in recall_ks}\n",
    "for i in tqdm(range(len(all_image_paths))):\n",
    "    image_feature = all_image_features[i]\n",
    "    logits = logit_scale * image_feature @ all_text_features.t()\n",
    "    ranking = torch.argsort(logits, descending=True).cpu().numpy()\n",
    "    relevant_texts = img_path2text[all_image_paths[i]]\n",
    "    for k in recall_ks:\n",
    "        if set(ranking[:k]) & relevant_texts:\n",
    "            img2text_hits[k] += 1\n",
    "for k in recall_ks:\n",
    "    res['img2text_R@' + str(k)] = img2text_hits[k] / len(all_image_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "id": "ec0e8d3d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'text2img_R@1': 0.3183023872679045,\n",
       " 'text2img_R@5': 0.6419098143236074,\n",
       " 'text2img_R@10': 0.8196286472148541,\n",
       " 'text2img_R@100': 0.9973474801061007,\n",
       " 'img2text_R@1': 0.38571428571428573,\n",
       " 'img2text_R@5': 0.8428571428571429,\n",
       " 'img2text_R@10': 0.9380952380952381,\n",
       " 'img2text_R@100': 1.0}"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Show the recall@k values for both retrieval directions.\n",
    "res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "271b5a18",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Environment (conda_pytorch_latest_p37)",
   "language": "python",
   "name": "conda_pytorch_latest_p37"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
